// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_MATRIXPROCESSOR_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_MATRIXPROCESSOR_H_

#include <mc_connections.h>
#include <systemc.h>

#include "Skewer.h"
#include "SystolicArray.h"
#include "src/AccelTypes.h"

// MatrixProcessor wires a SystolicArray to four skewer modules (inputs,
// weight-swap flags, incoming partial sums, outgoing partial sums) and drives
// them from two clocked threads:
//   * process_weights() forwards weight rows from weightsChannel to the array.
//   * run() pops a Params descriptor, streams inputs and partial sums into the
//     array, and routes results either to a local accumulation buffer or to
//     outputsChannel.
//
// Template parameters:
//   IDTYPE       element type of the vectors read from inputsChannel.
//   WDTYPE       element type of the vectors read from weightsChannel.
//   ODTYPE       element type of partial sums and of outputsChannel vectors.
//   NROWS        lane count of the input and weight-swap paths.
//   NCOLS        lane count of the weight, partial-sum and output paths.
//   BUFFER_SIZE  number of Pack1D<ODTYPE, NCOLS> entries in the accumulation
//                buffer used by run().
template <typename IDTYPE, typename WDTYPE, typename ODTYPE, int NROWS,
          int NCOLS, int BUFFER_SIZE>
SC_MODULE(MatrixProcessor) {
 private:
  // Handshake from process_weights() to run(): a full set of weights has been
  // pushed towards the array.
  Connections::SyncChannel CCS_INIT_S1(weightLoadDone);
  // Handshake bound to systolicArray.weightSwapDone and consumed by
  // process_weights() before it starts loading the next set of weights.
  Connections::SyncChannel CCS_INIT_S1(weightSwapDone);

  // Input path: one Pack1D of NROWS inputs in, NROWS scalar channels out.
  SerializedSkewer<IDTYPE, NROWS> CCS_INIT_S1(inputSkewer);
  Connections::Combinational<Pack1D<IDTYPE, NROWS> > CCS_INIT_S1(
      inputSkewerDin);

  // Weight-swap path: one 1-bit flag per row, travelling alongside the inputs.
  SerializedSkewer<ac_int<1, false>, NROWS> CCS_INIT_S1(weightSwapSkewer);
  Connections::Combinational<Pack1D<ac_int<1, false>, NROWS> > CCS_INIT_S1(
      weightSwapSkewerDin);

  // Incoming partial-sum path. It is NCOLS lanes wide because it feeds
  // psumsToSystolicArray[NCOLS] and run() pushes Pack1D<ODTYPE, NCOLS> into
  // it.
  SerializedSkewer<ODTYPE, NCOLS> CCS_INIT_S1(psumInSkewer);
  Connections::Combinational<Pack1D<ODTYPE, NCOLS> > CCS_INIT_S1(
      psumInSkewerDin);

  // Outgoing partial-sum path. It is NCOLS lanes wide because it is fed by
  // outputsFromSystolicArray[NCOLS] and run() pops Pack1D<ODTYPE, NCOLS> from
  // it.
  DeserializedSkewer<ODTYPE, NCOLS> CCS_INIT_S1(psumOutSkewer);
  Connections::Combinational<Pack1D<ODTYPE, NCOLS> > CCS_INIT_S1(
      psumOutSkewerDout);

  SystolicArray<IDTYPE, WDTYPE, ODTYPE, NROWS, NCOLS> CCS_INIT_S1(
      systolicArray);

 public:
  sc_in<bool> CCS_INIT_S1(clk);
  // Reset input; both threads register it with async_reset_signal_is(rstn,
  // false), i.e. it is active low.
  sc_in<bool> CCS_INIT_S1(rstn);

  Connections::In<Pack1D<IDTYPE, NROWS> > CCS_INIT_S1(inputsChannel);
  Connections::In<Pack1D<WDTYPE, NCOLS> > CCS_INIT_S1(weightsChannel);
  Connections::Out<Pack1D<ODTYPE, NCOLS> > CCS_INIT_S1(outputsChannel);
  Connections::In<Params> CCS_INIT_S1(paramsIn);

  // Channels between the skewers and the systolic array.
  Connections::Combinational<IDTYPE> inputsToSystolicArray[NROWS];
  Connections::Combinational<ac_int<1, false> >
      weightSwapToSystolicArray[NROWS];
  Connections::Combinational<ODTYPE> psumsToSystolicArray[NCOLS];
  Connections::Combinational<ODTYPE> outputsFromSystolicArray[NCOLS];
  // NOTE(review): this channel is declared with IDTYPE although
  // process_weights() pushes Pack1D<WDTYPE, NCOLS> into it. Left unchanged
  // because the type of systolicArray.weights is not visible here - confirm
  // against SystolicArray.h whether this should be WDTYPE.
  Connections::Combinational<Pack1D<IDTYPE, NCOLS> > CCS_INIT_S1(
      weightsToSystolicArray);

  // Binds clocks, resets and channels of all sub-modules and registers the
  // two clocked threads.
  SC_CTOR(MatrixProcessor) {
    systolicArray.clk(clk);
    systolicArray.rstn(rstn);
    for (int i = 0; i < NROWS; i++) {
      systolicArray.inputs[i](inputsToSystolicArray[i]);
    }
    for (int i = 0; i < NROWS; i++) {
      systolicArray.swapWeights[i](weightSwapToSystolicArray[i]);
    }
    for (int i = 0; i < NCOLS; i++) {
      systolicArray.psums[i](psumsToSystolicArray[i]);
    }
    for (int i = 0; i < NCOLS; i++) {
      systolicArray.outputs[i](outputsFromSystolicArray[i]);
    }
    systolicArray.weights(weightsToSystolicArray);
    systolicArray.weightSwapDone(weightSwapDone);

    inputSkewer.clk(clk);
    inputSkewer.rstn(rstn);
    inputSkewer.din(inputSkewerDin);
    for (int i = 0; i < NROWS; i++) {
      inputSkewer.dout[i](inputsToSystolicArray[i]);
    }

    psumInSkewer.clk(clk);
    psumInSkewer.rstn(rstn);
    psumInSkewer.din(psumInSkewerDin);
    // One skewer output per column: psumsToSystolicArray has NCOLS entries.
    for (int i = 0; i < NCOLS; i++) {
      psumInSkewer.dout[i](psumsToSystolicArray[i]);
    }

    psumOutSkewer.clk(clk);
    psumOutSkewer.rstn(rstn);
    for (int i = 0; i < NCOLS; i++) {
      psumOutSkewer.din[i](outputsFromSystolicArray[i]);
    }
    psumOutSkewer.dout(psumOutSkewerDout);

    weightSwapSkewer.clk(clk);
    weightSwapSkewer.rstn(rstn);
    weightSwapSkewer.din(weightSwapSkewerDin);
    // One skewer output per row: weightSwapToSystolicArray has NROWS entries.
    for (int i = 0; i < NROWS; i++) {
      weightSwapSkewer.dout[i](weightSwapToSystolicArray[i]);
    }

    SC_THREAD(process_weights);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Weight-loading thread. Repeats forever: forward NROWS weight vectors from
  // weightsChannel to the systolic array, signal weightLoadDone, then block
  // until the array reports weightSwapDone.
  void process_weights() {
    weightsChannel.Reset();
    weightsToSystolicArray.ResetWrite();
    weightLoadDone.ResetWrite();
    weightSwapDone.ResetRead();

    wait();

    while (true) {
#pragma hls_pipeline_init_interval 1
      for (int weight_count = 0; weight_count < NROWS; weight_count++) {
        Pack1D<WDTYPE, NCOLS> arrayWeights = weightsChannel.Pop();
        weightsToSystolicArray.Push(arrayWeights);
      }
      weightLoadDone.SyncPush();

      // wait for swap to propagate throughout the entire array before filling
      weightSwapDone.SyncPop();
    }
  }

  // Main control thread. For every Params descriptor popped from paramsIn:
  //   1. Computes totalOps as the product of all params.loops[2][6] bounds.
  //   2. Runs totalOps steps; each step pushes one input vector, one
  //      weight-swap vector and one partial-sum vector towards the array and
  //      opportunistically (non-blocking) collects one result vector.
  //   3. Drains the results that are still in flight until totalOps results
  //      have been collected.
  // loop_counters tracks the position of the vector being issued and
  // loop_counters_out the position of the vector being collected; both are
  // advanced like an odometer over params.loops.
  void run() {
    paramsIn.Reset();

    inputSkewerDin.ResetWrite();
    inputsChannel.Reset();
    weightLoadDone.ResetRead();
    psumInSkewerDin.ResetWrite();
    outputsChannel.Reset();
    weightSwapSkewerDin.ResetWrite();
    psumOutSkewerDout.ResetRead();

    wait();

    while (true) {
      Params params = paramsIn.Pop();

      int loop_counters[2][6];
      int loop_counters_out[2][6];

#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
#pragma hls_unroll yes
        for (int j = 0; j < 6; j++) {
          loop_counters[i][j] = 0;
          loop_counters_out[i][j] = 0;
        }
      }

      // Total number of steps is the product of all twelve loop bounds.
      int totalOps = 1;
#pragma hls_unroll yes
      for (int i = 0; i < 2; i++) {
#pragma hls_unroll yes
        for (int j = 0; j < 6; j++) {
          totalOps *= params.loops[i][j];
        }
      }

      // Holds partial sums between accumulation passes.
      // NOTE(review): read/write addresses below are not checked against
      // BUFFER_SIZE - the caller presumably guarantees they fit; confirm.
      Pack1D<ODTYPE, NCOLS> accumulation_buffer[BUFFER_SIZE];

      int step = 0;        // number of vectors issued so far
      int outputStep = 0;  // number of result vectors collected so far

      // Push inputs and psums into the systolic array
      // Pipelined across tiles

      // Inside this loop body `step < totalOps` always holds (it is the loop
      // condition and step only changes at the end of the body), so the
      // extra `step < totalOps` guards below never suppress anything.
#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      while (step < totalOps) {
#ifndef __SYNTHESIS__
        if (step % 1000 == 0) {
          CCS_LOG("step " << step << " out of " << totalOps);
        }
#endif

        // ensure all weights are ready for the next iteration
        Pack1D<ac_int<1, false>, NROWS> weightSwap;
        // A swap is requested whenever both inner-level counters selected by
        // params.weightReuseIndex are at zero.
        bool newWeights = loop_counters[1][params.weightReuseIndex[0]] == 0 &&
                          loop_counters[1][params.weightReuseIndex[1]] == 0;

        if (newWeights && step < totalOps) {
          // wait for weight loading to finish
          weightLoadDone.SyncPop();

#pragma hls_unroll yes
          for (int i = 0; i < NROWS; i++) {
            weightSwap.value[i] = true;
          }
        } else {
#pragma hls_unroll yes
          for (int i = 0; i < NROWS; i++) {
            weightSwap.value[i] = false;
          }
        }

        Pack1D<IDTYPE, NROWS> inputs;
        if (step < totalOps) {
          inputs = inputsChannel.Pop();
        }

        inputSkewerDin.Push(inputs);
        weightSwapSkewerDin.Push(weightSwap);

        // Partial sum fed into the array: zero on the first accumulation
        // pass, otherwise the value stored by a previous pass.
        Pack1D<ODTYPE, NCOLS> psum;
#pragma hls_unroll yes
        for (int i = 0; i < NCOLS; i++) {
          psum.value[i] = 0;
        }

        bool firstAccumulation =
            loop_counters[1][params.reductionLoopIndex[1]] == 0 &&
            loop_counters[1][params.fxIndex] == 0 &&
            loop_counters[1][params.fyIndex] == 0;

        if ((!firstAccumulation) && step < totalOps) {
          // Buffer address: (weight index, input Y, input X) flattened with
          // X as the fastest-varying dimension.
          int readAddress = loop_counters[1][params.weightLoopIndex[1]] *
                                params.loops[1][params.inputXLoopIndex[1]] *
                                params.loops[1][params.inputYLoopIndex[1]] +
                            loop_counters[1][params.inputYLoopIndex[1]] *
                                params.loops[1][params.inputXLoopIndex[1]] +
                            loop_counters[1][params.inputXLoopIndex[1]];
#ifdef __SYNTHESIS__
        READ_ACC_BUFFER:
#endif
          psum = accumulation_buffer[readAddress];
        }

        psumInSkewerDin.Push(psum);

        // Non-blocking: a result is only consumed if one is available.
        Pack1D<ODTYPE, NCOLS> outputs;
        if (psumOutSkewerDout.PopNB(outputs)) {
          outputStep++;

          // Write to accumulation buffer or to output channel if
          // accumulation is complete
          bool accumulationFinished =
              (loop_counters_out[1][params.reductionLoopIndex[1]] ==
               params.loops[1][params.reductionLoopIndex[1]] - 1) &&
              (loop_counters_out[1][params.fxIndex] ==
               params.loops[1][params.fxIndex] - 1) &&
              (loop_counters_out[1][params.fyIndex] ==
               params.loops[1][params.fyIndex] - 1);

          if (accumulationFinished) {
            outputsChannel.Push(outputs);
          } else {
            // Same flattening as readAddress, using the output-side counters.
            int writeAddress = loop_counters_out[1][params.weightLoopIndex[1]] *
                                   params.loops[1][params.inputXLoopIndex[1]] *
                                   params.loops[1][params.inputYLoopIndex[1]] +
                               loop_counters_out[1][params.inputYLoopIndex[1]] *
                                   params.loops[1][params.inputXLoopIndex[1]] +
                               loop_counters_out[1][params.inputXLoopIndex[1]];
#ifdef __SYNTHESIS__
          WRITE_ACC_BUFFER:
#endif
            accumulation_buffer[writeAddress] = outputs;
          }

          // Advance the output-side counters: bump the innermost one and
          // propagate carries outwards, wrapping from level 1 into level 0.
          loop_counters_out[1][5]++;
#pragma hls_unroll yes
          for (int i = 1; i >= 0; i--) {
#pragma hls_unroll yes
            for (int j = 5; j >= 0; j--) {
              if (loop_counters_out[i][j] == params.loops[i][j]) {
                loop_counters_out[i][j] = 0;
                if (j > 0) {
                  loop_counters_out[i][j - 1]++;
                } else {
                  if (i > 0) {
                    loop_counters_out[i - 1][5]++;
                  }
                }
              }
            }
          }
        }

        // Advance the issue-side counters with the same carry scheme.
        step++;
        loop_counters[1][5]++;
#pragma hls_unroll yes
        for (int i = 1; i >= 0; i--) {
#pragma hls_unroll yes
          for (int j = 5; j >= 0; j--) {
            if (loop_counters[i][j] == params.loops[i][j]) {
              loop_counters[i][j] = 0;
              if (j > 0) {
                loop_counters[i][j - 1]++;
              } else {
                if (i > 0) {
                  loop_counters[i - 1][5]++;
                }
              }
            }
          }
        }
      }

      // Drain out any remaining outputs
      // Every result collected here is forwarded to outputsChannel without
      // the accumulation-finished test used in the main loop.
      // NOTE(review): this presumes that all results still in flight after
      // the last issued step belong to completed accumulations - confirm.
#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
      while (outputStep < totalOps) {
        Pack1D<ODTYPE, NCOLS> outputs;
        if (psumOutSkewerDout.PopNB(outputs)) {
          outputStep++;
          outputsChannel.Push(outputs);

          // The counters are kept in step here, although nothing reads them
          // again before they are re-zeroed for the next Params descriptor.
          loop_counters_out[1][5]++;
#pragma hls_unroll yes
          for (int i = 1; i >= 0; i--) {
#pragma hls_unroll yes
            for (int j = 5; j >= 0; j--) {
              if (loop_counters_out[i][j] == params.loops[i][j]) {
                loop_counters_out[i][j] = 0;
                if (j > 0) {
                  loop_counters_out[i][j - 1]++;
                } else {
                  if (i > 0) {
                    loop_counters_out[i - 1][5]++;
                  }
                }
              }
            }
          }
        }
        wait();
      }
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_MATRIXPROCESSOR_H_
